{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/cookbooks/mistralai.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MistralAI Cookbook\n",
    "\n",
    "MistralAI released [mistral-large](https://mistral.ai/news/mistral-large/) model with enhancing capabilities of Function calling, reasoning, precise instruction-following, JSON mode and many more.\n",
    "\n",
    "This is a cook-book in showcasing the usage of `mistral-large` model with llama-index."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup LLM and Embedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()\n",
    "\n",
    "import os\n",
    "\n",
    "os.environ[\"MISTRAL_API_KEY\"] = \"YOUR MISTRALAI API KEY\"\n",
    "\n",
    "from llama_index.llms.mistralai import MistralAI\n",
    "from llama_index.embeddings.mistralai import MistralAIEmbedding\n",
    "from llama_index.core import Settings\n",
    "\n",
    "llm = MistralAI(model=\"mistral-large\", temperature=0.1)\n",
    "embed_model = MistralAIEmbedding(model_name=\"mistral-embed\")\n",
    "\n",
    "Settings.llm = llm\n",
    "Settings.embed_model = embed_model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download Data\n",
    "\n",
    "We will use `Uber-2021` and `Lyft-2021` 10K SEC filings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-02-27 01:17:30--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1880483 (1.8M) [application/octet-stream]\n",
      "Saving to: './uber_2021.pdf'\n",
      "\n",
      "./uber_2021.pdf     100%[===================>]   1.79M  7.16MB/s    in 0.3s    \n",
      "\n",
      "2024-02-27 01:17:31 (7.16 MB/s) - './uber_2021.pdf' saved [1880483/1880483]\n",
      "\n",
      "--2024-02-27 01:17:31--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1440303 (1.4M) [application/octet-stream]\n",
      "Saving to: './lyft_2021.pdf'\n",
      "\n",
      "./lyft_2021.pdf     100%[===================>]   1.37M  2.34MB/s    in 0.6s    \n",
      "\n",
      "2024-02-27 01:17:32 (2.34 MB/s) - './lyft_2021.pdf' saved [1440303/1440303]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' -O './uber_2021.pdf'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' -O './lyft_2021.pdf'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader\n",
    "\n",
    "uber_docs = SimpleDirectoryReader(input_files=[\"./uber_2021.pdf\"]).load_data()\n",
    "lyft_docs = SimpleDirectoryReader(input_files=[\"./lyft_2021.pdf\"]).load_data()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build RAG on uber docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The revenue of Uber in 2021 was $17,455 million.\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "uber_index = VectorStoreIndex.from_documents(uber_docs)\n",
    "uber_query_engine = uber_index.as_query_engine(similarity_top_k=5)\n",
    "\n",
    "response = uber_query_engine.query(\"What is the revenue of uber in 2021?\")\n",
    "print(response)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare `Uber` and `Lyft` revenue\n",
    "\n",
    "We will use `SubQuestionQueryEngine`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lyft_index = VectorStoreIndex.from_documents(lyft_docs)\n",
    "lyft_query_engine = lyft_index.as_query_engine(similarity_top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.tools import QueryEngineTool, ToolMetadata\n",
    "from llama_index.core.query_engine import SubQuestionQueryEngine\n",
    "\n",
    "query_engine_tools = [\n",
    "    QueryEngineTool(\n",
    "        query_engine=lyft_query_engine,\n",
    "        metadata=ToolMetadata(\n",
    "            name=\"lyft_10k\",\n",
    "            description=\"Provides information about Lyft financials for year 2021\",\n",
    "        ),\n",
    "    ),\n",
    "    QueryEngineTool(\n",
    "        query_engine=uber_query_engine,\n",
    "        metadata=ToolMetadata(\n",
    "            name=\"uber_10k\",\n",
    "            description=\"Provides information about Uber financials for year 2021\",\n",
    "        ),\n",
    "    ),\n",
    "]\n",
    "\n",
    "sub_question_query_engine = SubQuestionQueryEngine.from_defaults(\n",
    "    query_engine_tools=query_engine_tools\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated 4 sub questions.\n",
      "\u001b[1;3;38;2;237;90;200m[uber_10k] Q: What was the revenue of Uber in 2020\n",
      "\u001b[0m\u001b[1;3;38;2;90;149;237m[uber_10k] Q: What was the revenue of Uber in 2021\n",
      "\u001b[0m\u001b[1;3;38;2;11;159;203m[lyft_10k] Q: What was the revenue of Lyft in 2020\n",
      "\u001b[0m\u001b[1;3;38;2;155;135;227m[lyft_10k] Q: What was the revenue of Lyft in 2021\n",
      "\u001b[0m\u001b[1;3;38;2;155;135;227m[lyft_10k] A: The revenue of Lyft in 2021 was $3,208,323.\n",
      "\u001b[0m\u001b[1;3;38;2;90;149;237m[uber_10k] A: The revenue of Uber in 2021 was $17,455 million.\n",
      "\u001b[0m\u001b[1;3;38;2;11;159;203m[lyft_10k] A: The revenue of Lyft in 2020 was $2,364,681 (in thousands).\n",
      "\u001b[0m\u001b[1;3;38;2;237;90;200m[uber_10k] A: The revenue of Uber in 2020 was $11,139 million.\n",
      "\u001b[0mFrom 2020 to 2021, both Uber and Lyft experienced revenue growth. Uber's revenue increased from $11,139 million in 2020 to $17,455 million in 2021. On the other hand, Lyft's revenue grew from $2,364,681 (in thousands) in 2020 to $3,208,323 in 2021. This indicates that both companies had a positive growth trajectory in their revenues during this period.\n"
     ]
    }
   ],
   "source": [
    "response = await sub_question_query_engine.aquery(\n",
    "    \"Compare revenue growth of Uber and Lyft from 2020 to 2021\"\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Route queries between `Uber` and `Lyft`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SummaryIndex\n",
    "\n",
    "summary_index = SummaryIndex.from_documents(uber_docs)\n",
    "summary_query_engine = summary_index.as_query_engine(\n",
    "    response_mode=\"tree_summarize\",\n",
    "    use_async=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.tools import QueryEngineTool\n",
    "\n",
    "lyft_vector_tool = QueryEngineTool.from_defaults(\n",
    "    query_engine=lyft_query_engine,\n",
    "    description=(\n",
    "        \"Useful for retrieving specific context from lyft 10k SEC filings of 2021\"\n",
    "    ),\n",
    ")\n",
    "\n",
    "uber_vector_tool = QueryEngineTool.from_defaults(\n",
    "    query_engine=uber_query_engine,\n",
    "    description=(\n",
    "        \"Useful for retrieving specific context from uber 10k SEC filings of 2021\"\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import RouterQueryEngine\n",
    "from llama_index.core.selectors import LLMSingleSelector\n",
    "\n",
    "router_query_engine = RouterQueryEngine(\n",
    "    selector=LLMSingleSelector.from_defaults(),\n",
    "    query_engine_tools=[\n",
    "        lyft_vector_tool,\n",
    "        uber_vector_tool,\n",
    "    ],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;3;38;5;200mSelecting query engine 1: This choice is relevant because it pertains to Uber's 10k SEC filings of 2021, where the revenue information for the year is likely to be found..\n",
      "\u001b[0mThe revenue of Uber in 2021 was $17,455 million.\n"
     ]
    }
   ],
   "source": [
    "response = router_query_engine.query(\"What is the revenue of uber in 2021?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is most relevant to the question as it pertains to retrieving specific context from Lyft's 10k SEC filings of 2021, where information about Lyft's investments made in 2021 would likely be found..\n",
      "\u001b[0mIn 2021, Lyft made several investments to improve and expand their services. They continued to invest in the expansion of their network of Light Vehicles and Lyft Autonomous, which focuses on the deployment and scaling of third-party self-driving technology on the Lyft network. They also invested in their Express Drive program, which provides drivers access to rental cars they can use for ridesharing. Additionally, they made investments in their Driver Centers, Mobile Services, and related partnerships that offer drivers affordable and convenient vehicle maintenance services. Furthermore, they invested in their proprietary technology, including mapping, routing, payments, in-app navigation, matching technologies, and data science to make their network more efficient and seamless to use. They also acquired certain money market deposit accounts and cash in transit from payment processors for credit and debit card transactions. Short-term investments consisted of commercial paper, certificates of deposit, corporate bonds, and term deposits, which mature in 12 months or less. Restricted cash, cash equivalents, and investments consisted primarily of amounts held in separate trust accounts and restricted bank accounts as collateral for insurance purposes and amounts pledged to secure certain letters of credit.\n"
     ]
    }
   ],
   "source": [
    "response = router_query_engine.query(\n",
    "    \"What are the investments made by lyft in 2021?\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tools usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.agent import ReActAgent\n",
    "from llama_index.core.tools import FunctionTool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiple two integers and returns the result integer\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n",
    "\n",
    "\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers and returns the result integer\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "add_tool = FunctionTool.from_defaults(fn=add)\n",
    "\n",
    "tools = [multiply_tool, add_tool]\n",
    "agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;3;38;5;200mThought: I need to use a tool to multiply 26 and 2.\n",
      "Action: multiply\n",
      "Action Input: {'a': 26, 'b': 2}\n",
      "\u001b[0m\u001b[1;3;34mObservation: 52\n",
      "\u001b[0m\u001b[1;3;38;5;200mThought: I need to use a tool to add the result of the multiplication to 2024.\n",
      "Action: add\n",
      "Action Input: {'a': 52, 'b': 2024}\n",
      "\u001b[0m\u001b[1;3;34mObservation: 2076\n",
      "\u001b[0m\u001b[1;3;38;5;200mThought: I can answer without using any more tools.\n",
      "Answer: The result of (26 * 2) + 2024 is 2076.\n",
      "\u001b[0mThe result of (26 * 2) + 2024 is 2076.\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"What is (26 * 2) + 2024?\")\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
